{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a5eacac6-a7a4-4ed6-9de4-b589aaaa112c",
   "metadata": {},
   "source": [
    "# Extending\n",
    "\n",
    "You can extend FD-Shifts to benchmark your own models including new confidence scores, your own dataset and your own softmax-based confidence scoring functions. In this tutorial we will see how.\n",
    "\n",
    "If you want to run this locally you need to install optional dependencies:\n",
    "```bash\n",
    "pip install 'fd-shifts[docs] @ git+https://github.com/iml-dkfz/fd-shifts.git'\n",
    "```\n",
    "\n",
    "If you run this notebook in Google Colab, please execute the following cell, then reload the page (`Ctrl+R` or `Cmd+R`). If you want faster runtime you can switch to GPU acceleration in the Colab settings before executing this cell. Remember to switch the runtime to GPU again after reloading the page if you did that before executing this cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e4a7b52-102b-4bca-b51a-782dd99fe58b",
   "metadata": {
    "tags": [
     "skip-execution"
    ]
   },
   "outputs": [],
   "source": [
    "# Colab only: download the py310.sh installer, run it against /usr/local,\n",
    "# register a Jupyter kernel named \"py310\" and install FD-Shifts with its docs extras.\n",
    "!wget https://github.com/korakot/kora/releases/download/v0.10/py310.sh\n",
    "!bash ./py310.sh -b -f -p /usr/local\n",
    "!python -m ipykernel install --name \"py310\" --user\n",
    "!pip install 'fd-shifts[docs] @ git+https://github.com/iml-dkfz/fd-shifts.git'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "315acabd-7c8c-4744-8ed0-047907b23139",
   "metadata": {},
   "source": [
    "## Some Setup\n",
    "\n",
    "First we have to import a lot of stuff and create a config object. Make sure to set `EXPERIMENT_ROOT_DIR` and `DATASET_ROOT_DIR` appropriately beforehand."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cbebc14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Place the experiment and dataset folders next to each other in the parent of\n",
    "# the current working directory. Adjust these two paths to your own setup.\n",
    "cwd = os.getcwd()\n",
    "fd_dir = os.path.abspath(os.path.join(cwd, os.pardir))\n",
    "os.environ[\"EXPERIMENT_ROOT_DIR\"] = fd_dir + \"/experiments_folder\"\n",
    "os.environ[\"DATASET_ROOT_DIR\"] = fd_dir + \"/data_folder\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e966ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable, Optional\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar\n",
    "from rich import print\n",
    "from torch import nn\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "from fd_shifts import analysis, configs, models, reporting\n",
    "from fd_shifts.exec import test, train\n",
    "from fd_shifts.loaders import dataset_collection\n",
    "from fd_shifts.loaders.data_loader import FDShiftsDataLoader\n",
    "from fd_shifts.utils import exp_utils\n",
    "\n",
    "configs.init()\n",
    "# Start from the default configuration for the SVHN dataset.\n",
    "config = configs.Config.with_defaults(data=\"svhn\")\n",
    "\n",
    "# Restrict the new-class study to CIFAR-10 and CIFAR-100, leaving out tinyimagenet.\n",
    "# You can comment this line out if you downloaded tinyimagenet_resize and put it\n",
    "# into $DATASET_ROOT_DIR.\n",
    "config.eval.query_studies.new_class_study = [\"cifar10\", \"cifar100\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6742ca5d",
   "metadata": {},
   "source": [
    "## Adding a New Model Including a New Confidence Scoring Function\n",
    "\n",
    "Let's start with adding a new model. First we set up a model class inheriting from `LightningModule`. We have to create methods for `train`, `validation` and `test` steps for the benchmark to work. We will also copy over `load_only_state_dict`, a helper method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "363ebbe8-b0ad-4fdb-b6c1-49e2f75ca96a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class MyModel(pl.LightningModule):\n",
    "    \"\"\"Fully connected classifier that also emits a custom confidence score.\n",
    "\n",
    "    The input is flattened, passed through two 512-unit ReLU layers and\n",
    "    classified by a linear layer. Validation and test steps additionally\n",
    "    report the sum of the logits under the ``\"confid\"`` key.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, cfg: configs.Config) -> None:\n",
    "        \"\"\"Build the network from the image size and class count in ``cfg.data``.\"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.cfg = cfg\n",
    "        self.ext_confid_name = \"my_fancy_confidence\"\n",
    "\n",
    "        # The first layer sees the flattened image, so it needs one input per\n",
    "        # entry of the image.\n",
    "        num_inputs = cfg.data.img_size[0] * cfg.data.img_size[1] * cfg.data.img_size[2]\n",
    "\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(num_inputs, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 512),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "\n",
    "        # Size the output layer from the config rather than hard-coding 10\n",
    "        # classes, so the model also fits datasets with another class count.\n",
    "        self.classifier = nn.Linear(512, cfg.data.num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the class logits for a batch of inputs.\"\"\"\n",
    "        return self.classifier(self.encoder(x))\n",
    "\n",
    "    def _match_channels(self, x):\n",
    "        \"\"\"Convert multi-channel inputs to grayscale if the model expects one channel.\n",
    "\n",
    "        Needed when a single-channel model is evaluated on data that has more\n",
    "        channels than the model was configured for.\n",
    "        \"\"\"\n",
    "        if x.shape[1] > 1 and self.cfg.data.img_size[2] == 1:\n",
    "            x = transforms.Grayscale()(x)\n",
    "        return x\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Return the cross-entropy loss, softmax and labels for one training batch.\"\"\"\n",
    "        x, y = batch\n",
    "        logits = self(x)\n",
    "        loss = torch.nn.functional.cross_entropy(logits, y)\n",
    "        return {\"loss\": loss, \"softmax\": torch.softmax(logits, dim=1), \"labels\": y}\n",
    "\n",
    "    def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
    "        \"\"\"Return loss, softmax, labels and the custom confidence for one batch.\"\"\"\n",
    "        x, y = batch\n",
    "        logits = self(self._match_channels(x))\n",
    "        loss = torch.nn.functional.cross_entropy(logits, y)\n",
    "        my_fancy_confidence = torch.sum(logits, dim=1)\n",
    "        return {\n",
    "            \"loss\": loss,\n",
    "            \"softmax\": torch.softmax(logits, dim=1),\n",
    "            \"labels\": y,\n",
    "            \"confid\": my_fancy_confidence,\n",
    "        }\n",
    "\n",
    "    def test_step(self, batch, batch_idx, dataloader_idx=0):\n",
    "        \"\"\"Store logits, labels and the custom confidence in ``self.test_results``.\"\"\"\n",
    "        x, y = batch\n",
    "        logits = self(self._match_channels(x))\n",
    "        my_fancy_confidence = torch.sum(logits, dim=1)\n",
    "        self.test_results = {\n",
    "            \"logits\": logits,\n",
    "            \"labels\": y,\n",
    "            \"confid\": my_fancy_confidence,\n",
    "        }\n",
    "\n",
    "    def load_only_state_dict(self, path):\n",
    "        \"\"\"Load only the model weights from the checkpoint at ``path``.\n",
    "\n",
    "        The checkpoint is unpickled by ``torch.load``, so only load files you trust.\n",
    "        \"\"\"\n",
    "        # Map to CPU first so that a checkpoint written on a GPU machine can be\n",
    "        # restored without a GPU; load_state_dict then copies the values into\n",
    "        # the module's existing parameters.\n",
    "        ckpt = torch.load(path, map_location=\"cpu\")\n",
    "        self.load_state_dict(ckpt[\"state_dict\"], strict=True)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Use Adam with a cosine-annealed learning rate over the configured epochs.\"\"\"\n",
    "        optimizer = torch.optim.Adam(self.parameters())\n",
    "        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
    "            optimizer, T_max=self.cfg.trainer.num_epochs\n",
    "        )\n",
    "        return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ee3b748-8356-443a-b64b-0818ece621bb",
   "metadata": {},
   "source": [
    "### Registering and Setting Up the Config\n",
    "\n",
    "To use the model we have to tell FD-Shifts about it and update our configuration. We will also use this opportunity to update the default configuration with some information about our experiment. Afterwards we can train the model. In this case we train on SVHN since that is the dataset for which we loaded the default configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb9c5581-d1c7-4b7c-908f-8cc2eb0a5f96",
   "metadata": {
    "nbmake": {
     "mock": {
      "config.trainer.batch_size": 64,
      "config.trainer.num_epochs": 5
     }
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Make the model class known to FD-Shifts under the name \"my_model\" and select it.\n",
    "models.register_model(\"my_model\", MyModel)\n",
    "\n",
    "config.model.name = \"my_model\"\n",
    "# NOTE(review): presumably the label under which the model's \"confid\" output is\n",
    "# evaluated; confirm against the FD-Shifts evaluation code.\n",
    "config.eval.ext_confid_name = \"ext\"\n",
    "\n",
    "config.trainer.num_epochs = 10\n",
    "config.trainer.batch_size = 256\n",
    "\n",
    "# update_experiment returns the config to use from here on, so rebind `config`.\n",
    "name = \"my_first_experiment\"\n",
    "config = config.update_experiment(name)\n",
    "\n",
    "print(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6ba1041-91ba-4673-9f73-86befe843ae1",
   "metadata": {},
   "source": [
    "### Training the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71bed0da-ec2e-406b-b2af-633597377d89",
   "metadata": {},
   "outputs": [],
   "source": [
    "train(config, RichProgressBar())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a2ca5d8-1691-4d4f-9de3-307ed4dc5fed",
   "metadata": {},
   "source": [
    "### Testing the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c29233bb-acfa-480e-ba04-8f2e2ecdaf68",
   "metadata": {},
   "source": [
    "Now let's test on SVHN as well as some additional datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7aaf9fe",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "test(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11a0b8d1-4c76-4a9d-bd75-0485a868e7b1",
   "metadata": {},
   "source": [
    "### Reporting Our Results\n",
    "\n",
    "All computed metrics are now found in various `csv` files in the experiment folder. Let's load and preprocess them and then render a benchmark table similar to the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f068913",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every results file below the experiment root, labelling each one with\n",
    "# the name of the directory three levels above it, then run the reporting\n",
    "# preprocessing steps in order.\n",
    "data = (\n",
    "    pd.concat(\n",
    "        [\n",
    "            reporting.load_file(\n",
    "                csv_path, experiment_override=str(csv_path.parent.parent.parent.stem)\n",
    "            )\n",
    "            for csv_path in Path(os.getenv(\"EXPERIMENT_ROOT_DIR\")).glob(\n",
    "                \"**/test_results/*.csv\"\n",
    "            )\n",
    "        ]\n",
    "    )\n",
    "    .assign(study=lambda frame: frame.experiment + \"_\" + frame.study)\n",
    "    .pipe(reporting.assign_hparams_from_names)\n",
    "    .pipe(reporting.rename_confids)\n",
    "    .pipe(reporting.rename_studies)\n",
    "    .pipe(reporting.tables.aggregate_over_runs)\n",
    "    .pipe(reporting.str_format_metrics)\n",
    ")\n",
    "\n",
    "results_table = reporting.tables.build_results_table(\n",
    "    data=data, metric=\"aurc\", original_mode=False, paper_filter=False\n",
    ")\n",
    "results_table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7aeb941-4283-48db-8a62-d4545e1be12f",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "## Adding a New Dataset\n",
    "\n",
    "We can also evaluate our model (or one of the built in ones) on a custom dataset. Let's define a dataset that is just a wrapper around MNIST for simplicity and tell FD-Shifts about it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cb51c25",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "class MyDataset(datasets.MNIST):\n",
    "    \"\"\"MNIST under a different name, standing in for a custom dataset.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        root: str,\n",
    "        train: bool = True,\n",
    "        transform: Optional[Callable] = None,\n",
    "        target_transform: Optional[Callable] = None,\n",
    "        download: bool = False,\n",
    "    ) -> None:\n",
    "        \"\"\"Forward all arguments unchanged to ``datasets.MNIST``.\"\"\"\n",
    "        super().__init__(\n",
    "            root,\n",
    "            train=train,\n",
    "            transform=transform,\n",
    "            target_transform=target_transform,\n",
    "            download=download,\n",
    "        )\n",
    "\n",
    "\n",
    "# Make the dataset available to FD-Shifts under the name \"mydataset\".\n",
    "dataset_collection.register_dataset(\"mydataset\", MyDataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "158dddce",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "874fd88b-3404-41ab-9be6-7adb4fde2338",
   "metadata": {},
   "source": [
    "### Updating the Configuration\n",
    "\n",
    "We now have to update our configuration with the new dataset. We will also update the list of datasets to additionally test on. Since we train on MNIST we could argue that SVHN provides a sub-class shift, while CIFAR-10 might be a good choice for a new-class shift."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7131d5fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "# Keep a snapshot of the SVHN experiment's config; it is used again further\n",
    "# down to re-run the analysis for that experiment.\n",
    "config_svhn = deepcopy(config)\n",
    "\n",
    "# Describe the new dataset. The last entry of img_size is 1, which is what the\n",
    "# model checks before converting multi-channel inputs to grayscale.\n",
    "config.data = configs.DataConfig(\n",
    "    dataset=\"mydataset\",\n",
    "    data_dir=config.data.data_dir.parent / \"mydataset\",\n",
    "    pin_memory=True,\n",
    "    img_size=(32, 32, 1),\n",
    "    num_workers=12,\n",
    "    num_classes=10,\n",
    "    reproduce_confidnet_splits=True,\n",
    "    augmentations={\n",
    "        \"train\": {\n",
    "            \"to_tensor\": None,\n",
    "            \"random_crop\": [32, 4],\n",
    "            \"normalize\": [[0.5], [0.5]],\n",
    "        },\n",
    "        \"val\": {\n",
    "            \"to_tensor\": None,\n",
    "            \"resize\": 32,\n",
    "            \"normalize\": [[0.5], [0.5]],\n",
    "        },\n",
    "        \"test\": {\n",
    "            \"to_tensor\": None,\n",
    "            \"resize\": 32,\n",
    "            \"normalize\": [[0.5], [0.5]],\n",
    "        },\n",
    "    },\n",
    ")\n",
    "# Evaluate in-distribution on the new dataset, with SVHN as the in-class study\n",
    "# and CIFAR-10 as the new-class study; no noise study is run.\n",
    "config.eval.query_studies.iid_study = \"mydataset\"\n",
    "config.eval.query_studies.noise_study = []\n",
    "config.eval.query_studies.in_class_study = [\"svhn\"]\n",
    "config.eval.query_studies.new_class_study = [\"cifar10\"]\n",
    "print(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2705e3a2-fd93-4b22-8e72-15b95b96dd2d",
   "metadata": {},
   "source": [
    "Let's take a look at the data loader this configuration will create."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e8edef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "datamodule = FDShiftsDataLoader(config)\n",
    "\n",
    "# Follow the Lightning data module order: prepare (e.g. download) the data\n",
    "# first, then set up the splits that read it. The reverse order can fail when\n",
    "# the data directory is still empty.\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup()\n",
    "\n",
    "# Take the first training batch to have an example image to look at.\n",
    "x, y = next(iter(datamodule.train_dataloader()))\n",
    "\n",
    "\n",
    "def tensor_to_image(t: torch.Tensor):\n",
    "    \"\"\"Move the first axis of a 3-D tensor to the end and return it as a NumPy array.\n",
    "\n",
    "    This turns a channel-first image tensor into the channel-last layout that\n",
    "    ``imshow`` expects.\n",
    "    \"\"\"\n",
    "    return t.cpu().numpy().transpose(1, 2, 0)\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(tensor_to_image(x[0]))\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1e858b6-e7f4-4ccd-b74a-dc5b0c872e1a",
   "metadata": {},
   "source": [
    "### Training and Testing on the New Dataset\n",
    "We can now train and test on this dataset and recompute our results table. It will now display both SVHN and our new dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aad43968-1696-4cc2-8daf-06837b539781",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this reuses the experiment name of the SVHN run. Presumably\n",
    "# update_experiment also takes the dataset into account when deriving the\n",
    "# experiment directory; confirm, otherwise the earlier results are overwritten.\n",
    "name = \"my_first_experiment\"\n",
    "config = config.update_experiment(name)\n",
    "\n",
    "print(config)\n",
    "\n",
    "train(config, RichProgressBar())\n",
    "test(config)\n",
    "\n",
    "# Rebuild the results table from every results file below the experiment root,\n",
    "# labelling each file with the name of the directory three levels above it.\n",
    "data = (\n",
    "    pd.concat(\n",
    "        [\n",
    "            reporting.load_file(\n",
    "                csv_path, experiment_override=str(csv_path.parent.parent.parent.stem)\n",
    "            )\n",
    "            for csv_path in Path(os.getenv(\"EXPERIMENT_ROOT_DIR\")).glob(\n",
    "                \"**/test_results/*.csv\"\n",
    "            )\n",
    "        ]\n",
    "    )\n",
    "    .assign(study=lambda frame: frame.experiment + \"_\" + frame.study)\n",
    "    .pipe(reporting.assign_hparams_from_names)\n",
    "    .pipe(reporting.rename_confids)\n",
    "    .pipe(reporting.rename_studies)\n",
    "    .pipe(reporting.tables.aggregate_over_runs)\n",
    "    .pipe(reporting.str_format_metrics)\n",
    ")\n",
    "\n",
    "results_table = reporting.tables.build_results_table(\n",
    "    data=data, metric=\"aurc\", original_mode=False, paper_filter=False\n",
    ")\n",
    "results_table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "095b4baf",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "## Adding a New Softmax Based Confidence Scoring Function\n",
    "\n",
    "We can also add a new softmax-based confidence scoring function; we just have to tell FD-Shifts about it. Afterwards we need to rerun the analysis for both of our experiments and update the results table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e94fe885",
   "metadata": {},
   "outputs": [],
   "source": [
    "@analysis.confid_scores.register_confid_func(\"my_csf\")\n",
    "def my_fancy_csf(softmax):\n",
    "    \"\"\"Score confidence as one minus the smallest softmax entry along axis 1.\"\"\"\n",
    "    return 1 - np.min(softmax, axis=1)\n",
    "\n",
    "\n",
    "# Re-run the analysis for the new-dataset experiment and for the SVHN\n",
    "# experiment so that both include the new score.\n",
    "for experiment_config in (config, config_svhn):\n",
    "    test_measures = experiment_config.eval.confidence_measures.test\n",
    "    # Guard the append so that re-running this cell does not list the score twice.\n",
    "    if \"my_csf\" not in test_measures:\n",
    "        test_measures.append(\"my_csf\")\n",
    "\n",
    "    analysis.main(\n",
    "        in_path=experiment_config.test.dir,\n",
    "        out_path=experiment_config.test.dir,\n",
    "        query_studies=experiment_config.eval.query_studies,\n",
    "        add_val_tuning=experiment_config.eval.val_tuning,\n",
    "        threshold_plot_confid=None,\n",
    "        cf=experiment_config,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ecc3a87-13f7-46e3-897b-58b7aa195c2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload every results file below the experiment root so the table picks up\n",
    "# the re-run analysis, labelling each file with the name of the directory three\n",
    "# levels above it.\n",
    "data = (\n",
    "    pd.concat(\n",
    "        [\n",
    "            reporting.load_file(\n",
    "                csv_path, experiment_override=str(csv_path.parent.parent.parent.stem)\n",
    "            )\n",
    "            for csv_path in Path(os.getenv(\"EXPERIMENT_ROOT_DIR\")).glob(\n",
    "                \"**/test_results/*.csv\"\n",
    "            )\n",
    "        ]\n",
    "    )\n",
    "    .assign(study=lambda frame: frame.experiment + \"_\" + frame.study)\n",
    "    .pipe(reporting.assign_hparams_from_names)\n",
    "    .pipe(reporting.rename_confids)\n",
    "    .pipe(reporting.rename_studies)\n",
    "    .pipe(reporting.tables.aggregate_over_runs)\n",
    "    .pipe(reporting.str_format_metrics)\n",
    ")\n",
    "\n",
    "results_table = reporting.tables.build_results_table(\n",
    "    data=data, metric=\"aurc\", original_mode=False, paper_filter=False\n",
    ")\n",
    "results_table"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "py310"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
